{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Subword Embedding\n",
    ":label:`sec_fasttext`\n",
    "\n",
    "English words usually have internal structures and formation methods. For example, we can deduce the relationship between \"dog\", \"dogs\", and \"dogcatcher\" by their spelling. All these words have the same root, \"dog\", but they use different suffixes to change the meaning of the word. Moreover, this association can be extended to other words. For example, the relationship between \"dog\" and \"dogs\" is just like the relationship between \"cat\" and \"cats\". The relationship between \"boy\" and \"boyfriend\" is just like the relationship between \"girl\" and \"girlfriend\". This characteristic is not unique to English. In French and Spanish, a lot of verbs can have more than 40 different forms depending on the context. In Finnish, a noun may have more than 15 forms. In fact, morphology, which is an important branch of linguistics, studies the internal structure and formation of words.\n",
    "\n",
    "\n",
    "## fastText\n",
    "\n",
    "In word2vec, we did not directly use morphology information.  In both the\n",
    "skip-gram model and continuous bag-of-words model, we use different vectors to\n",
    "represent words with different forms. For example, \"dog\" and \"dogs\" are\n",
    "represented by two different vectors, while the relationship between these two\n",
    "vectors is not directly represented in the model. In view of this, fastText :cite:`Bojanowski.Grave.Joulin.ea.2017`\n",
    "proposes the method of subword embedding, thereby attempting to introduce\n",
    "morphological information in the skip-gram model in word2vec.\n",
    "\n",
    "In fastText, each central word is represented as a collection of subwords. Below we use the word \"where\" as an example to understand how subwords are formed. First, we add the special characters “&lt;” and “&gt;” at the beginning and end of the word to distinguish the subwords used as prefixes and suffixes. Then, we treat the word as a sequence of characters to extract the $n$-grams. For example, when $n=3$, we can get all subwords with a length of $3$:\n",
    "\n",
    "$$\\textrm{\"<wh\"}, \\ \\textrm{\"whe\"}, \\ \\textrm{\"her\"}, \\ \\textrm{\"ere\"}, \\ \\textrm{\"re>\"},$$\n",
    "\n",
    "and the special subword $\\textrm{\"<where>\"}$.\n",
    "\n",
    "In fastText, for a word $w$, we record the union of all its subwords with length of $3$ to $6$ and special subwords as $\\mathcal{G}_w$. Thus, the dictionary is the union of the collection of subwords of all words. Assume the vector of the subword $g$ in the dictionary is $\\mathbf{z}_g$. Then, the central word vector $\\mathbf{u}_w$ for the word $w$ in the skip-gram model can be expressed as\n",
    "\n",
    "$$\\mathbf{u}_w = \\sum_{g\\in\\mathcal{G}_w} \\mathbf{z}_g.$$\n",
    "\n",
    "The rest of the fastText process is consistent with the skip-gram model, so it is not repeated here. As we can see, compared with the skip-gram model, the dictionary in fastText is larger, resulting in more model parameters. Also, the vector of one word requires the summation of all subword vectors, which results in higher computation complexity. However, we can obtain better vectors for more uncommon complex words, even words not existing in the dictionary, by looking at other words with similar structures.\n",
    "\n",
    "\n",
    "## Byte Pair Encoding\n",
    ":label:`subsec_Byte_Pair_Encoding`\n",
    "\n",
    "In fastText, all the extracted subwords have to be of the specified lengths, such as $3$ to $6$, thus the vocabulary size cannot be predefined.\n",
    "To allow for variable-length subwords in a fixed-size vocabulary,\n",
    "we can apply a compression algorithm\n",
    "called *byte pair encoding* (BPE) to extract subwords :cite:`Sennrich.Haddow.Birch.2015`.\n",
    "\n",
    "Byte pair encoding performs a statistical analysis of the training dataset to discover common symbols within a word,\n",
    "such as consecutive characters of arbitrary length.\n",
    "Starting from symbols of length $1$,\n",
    "byte pair encoding iteratively merges the most frequent pair of consecutive symbols to produce new longer symbols.\n",
    "Note that for efficiency, pairs crossing word boundaries are not considered.\n",
    "In the end, we can use such symbols as subwords to segment words.\n",
    "Byte pair encoding and its variants has been used for input representations in popular natural language processing pretraining models such as GPT-2 :cite:`Radford.Wu.Child.ea.2019` and RoBERTa :cite:`Liu.Ott.Goyal.ea.2019`.\n",
    "In the following, we will illustrate how byte pair encoding works.\n",
    "\n",
    "First, we initialize the vocabulary of symbols as all the English lowercase characters, a special end-of-word symbol `'_'`, and a special unknown symbol `'[UNK]'`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/Functions.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.stream.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] symbols =\n",
    "        new String[] {\n",
    "            \"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\", \"i\", \"j\", \"k\", \"l\", \"m\", \"n\", \"o\", \"p\",\n",
    "            \"q\", \"r\", \"s\", \"t\", \"u\", \"v\", \"w\", \"x\", \"y\", \"z\", \"_\", \"[UNK]\"\n",
    "        };"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "Since we do not consider symbol pairs that cross boundaries of words,\n",
    "we only need a dictionary `rawTokenFreqs` that maps words to their frequencies (number of occurrences)\n",
    "in a dataset.\n",
    "Note that the special symbol `'_'` is appended to each word so that\n",
    "we can easily recover a word sequence (e.g., \"a taller man\")\n",
    "from a sequence of output symbols ( e.g., \"a_ tall er_ man\").\n",
    "Since we start the merging process from a vocabulary of only single characters and special symbols, space is inserted between every pair of consecutive characters within each word (keys of the dictionary `tokenFreqs`).\n",
    "In other words, space is the delimiter between symbols within a word.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "HashMap<String, Integer> rawTokenFreqs = new HashMap<>();\n",
    "rawTokenFreqs.put(\"fast_\", 4);\n",
    "rawTokenFreqs.put(\"faster_\", 3);\n",
    "rawTokenFreqs.put(\"tall_\", 5);\n",
    "rawTokenFreqs.put(\"taller_\", 4);\n",
    "\n",
    "HashMap<String, Integer> tokenFreqs = new HashMap<>();\n",
    "for (Map.Entry<String, Integer> e : rawTokenFreqs.entrySet()) {\n",
    "    String token = e.getKey();\n",
    "    tokenFreqs.put(String.join(\" \", token.split(\"\")), rawTokenFreqs.get(token));\n",
    "}\n",
    "\n",
    "tokenFreqs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "We define the following `getMaxFreqPair` function that \n",
    "returns the most frequent pair of consecutive symbols within a word,\n",
    "where words come from keys of the input dictionary `tokenFreqs`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<String, String> getMaxFreqPair(HashMap<String, Integer> tokenFreqs) {\n",
    "    HashMap<Pair<String, String>, Integer> pairs = new HashMap<>();\n",
    "    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {\n",
    "        // Key of 'pairs' is a tuple of two consecutive symbols\n",
    "        String token = e.getKey();\n",
    "        Integer freq = e.getValue();\n",
    "        String[] symbols = token.split(\" \");\n",
    "        for (int i = 0; i < symbols.length - 1; i++) {\n",
    "            pairs.put(\n",
    "                    new Pair<>(symbols[i], symbols[i + 1]),\n",
    "                    pairs.getOrDefault(new Pair<>(symbols[i], symbols[i + 1]), 0) + freq);\n",
    "        }\n",
    "    }\n",
    "    int max = 0; // Key of `pairs` with the max value\n",
    "    Pair<String, String> maxFreqPair = null;\n",
    "    for (Map.Entry<Pair<String, String>, Integer> pair : pairs.entrySet()) {\n",
    "        if (max < pair.getValue()) {\n",
    "            max = pair.getValue();\n",
    "            maxFreqPair = pair.getKey();\n",
    "        }\n",
    "    }\n",
    "    return maxFreqPair;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "As a greedy approach based on frequency of consecutive symbols,\n",
    "byte pair encoding will use the following `mergeSymbols` function to merge the most frequent pair of consecutive symbols to produce new symbols.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 7,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public static Pair<HashMap<String, Integer>, String[]> mergeSymbols(\n",
    "        Pair<String, String> maxFreqPair, HashMap<String, Integer> tokenFreqs) {\n",
    "    ArrayList<String> symbols = new ArrayList<>();\n",
    "    symbols.add(maxFreqPair.getKey() + maxFreqPair.getValue());\n",
    "\n",
    "    HashMap<String, Integer> newTokenFreqs = new HashMap<>();\n",
    "    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {\n",
    "        String token = e.getKey();\n",
    "        String newToken =\n",
    "                token.replace(\n",
    "                        maxFreqPair.getKey() + \" \" + maxFreqPair.getValue(),\n",
    "                        maxFreqPair.getKey() + \"\" + maxFreqPair.getValue());\n",
    "        newTokenFreqs.put(newToken, tokenFreqs.get(token));\n",
    "    }\n",
    "    return new Pair(newTokenFreqs, symbols.toArray(new String[symbols.size()]));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "Now we iteratively perform the byte pair encoding algorithm over the keys of the dictionary `tokenFreqs`. In the first iteration, the most frequent pair of consecutive symbols are `'t'` and `'a'`, thus byte pair encoding merges them to produce a new symbol `'ta'`. In the second iteration, byte pair encoding continues to merge `'ta'` and `'l'` to result in another new symbol `'tal'`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numMerges = 10;\n",
    "for (int i = 0; i < numMerges; i++) {\n",
    "    Pair<String, String> maxFreqPair = getMaxFreqPair(tokenFreqs);\n",
    "    Pair<HashMap<String, Integer>, String[]> pair =\n",
    "            mergeSymbols(maxFreqPair, tokenFreqs);\n",
    "    tokenFreqs = pair.getKey();\n",
    "    symbols =\n",
    "            Stream.concat(Arrays.stream(symbols), Arrays.stream(pair.getValue()))\n",
    "                    .toArray(String[]::new);\n",
    "    System.out.println(\n",
    "            \"merge #\"\n",
    "                    + (i + 1)\n",
    "                    + \": (\"\n",
    "                    + maxFreqPair.getKey()\n",
    "                    + \", \"\n",
    "                    + maxFreqPair.getValue()\n",
    "                    + \")\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "After 10 iterations of byte pair encoding, we can see that list `symbols` now contains 10 more symbols that are iteratively merged from other symbols.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Arrays.toString(symbols)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "For the same dataset specified in the keys of the dictionary `raw_token_freqs`,\n",
    "each word in the dataset is now segmented by subwords \"fast_\", \"fast\", \"er_\", \"tall_\", and \"tall\"\n",
    "as a result of the byte pair encoding algorithm.\n",
    "For instance, words \"faster_\" and \"taller_\" are segmented as \"fast er_\" and \"tall er_\", respectively.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenFreqs.keySet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "Note that the result of byte pair encoding depends on the dataset being used.\n",
    "We can also use the subwords learned from one dataset\n",
    "to segment words of another dataset.\n",
    "As a greedy approach, the following `segmentBPE` function tries to break words into the longest possible subwords from the input argument `symbols`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static List<String> segmentBPE(String[] tokens, String[] symbols) {\n",
    "    List<String> outputs = new ArrayList<>();\n",
    "    for (String token : tokens) {\n",
    "        int start = 0;\n",
    "        int end = token.length();\n",
    "        ArrayList<String> curOutput = new ArrayList<>();\n",
    "        // Segment token with the longest possible subwords from symbols\n",
    "        while (start < token.length() && start < end) {\n",
    "            if (Arrays.asList(symbols).contains(token.substring(start, end))) {\n",
    "                curOutput.add(token.substring(start, end));\n",
    "                start = end;\n",
    "                end = token.length();\n",
    "            } else {\n",
    "                end -= 1;\n",
    "            }\n",
    "        }\n",
    "        if (start < tokens.length) {\n",
    "            curOutput.add(\"[UNK]\");\n",
    "        }\n",
    "        String temp = \"\";\n",
    "        for (String s : curOutput) {\n",
    "            temp += s + \" \";\n",
    "        }\n",
    "        outputs.add(temp.trim());\n",
    "    }\n",
    "    return outputs;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "In the following, we use the subwords in list `symbols`, which is learned from the aforementioned dataset,\n",
    "to segment `tokens` that represent another dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 17,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "String[] tokens = new String[] {\"tallest_\", \"fatter_\"};\n",
    "System.out.println(segmentBPE(tokens, symbols));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "## Summary\n",
    "\n",
    "* FastText proposes a subword embedding method. Based on the skip-gram model in word2vec, it represents the central word vector as the sum of the subword vectors of the word.\n",
    "* Subword embedding utilizes the principles of morphology, which usually improves the quality of representations of uncommon words.\n",
    "* Byte pair encoding performs a statistical analysis of the training dataset to discover common symbols within a word. As a greedy approach, byte pair encoding iteratively merges the most frequent pair of consecutive symbols.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. When there are too many subwords (for example, 6 words in English result in about $3\\times 10^8$ combinations), what problems arise? Can you think of any methods to solve them? Hint: Refer to the end of section 3.2 of the fastText paper :cite:`Bojanowski.Grave.Joulin.ea.2017`.\n",
    "1. How can you design a subword embedding model based on the continuous bag-of-words model?\n",
    "1. To get a vocabulary of size $m$, how many merging operations are needed when the initial symbol vocabulary size is $n$?\n",
    "1. How can we extend the idea of byte pair encoding to extract phrases?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
